import torch
import sys
import os
# Add the parent directory (one level above this file's directory) to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from optimizer.alto import create_ALTO_optimizer
from optimizer.lamb import create_lamb_optimizer
from optimizer.expdt_sgd import create_esgd_optimizer
from optimizer.expdt_adam import create_eadam_optimizer
from optimizer.adaptor import OptimizerAdaptor
def get_optimizer(args, model):
    """Build the optimizer selected by ``args.opt`` for ``model``.

    Supported names:

    * ``'sgd'``, ``'adam'``, ``'adamw'`` -- the corresponding ``torch.optim``
      classes.
    * ``'lamb'`` -- built by ``create_lamb_optimizer``.
    * ``'alto'`` -- built by ``create_ALTO_optimizer``.
    * ``'esgd'``, ``'eadam'``, ``'eadamw'``, ``'elamb'`` -- the matching base
      optimizer wrapped in ``OptimizerAdaptor`` together with ``args.alpha``
      and ``args.beta``.

    Only the hyper-parameters needed by the selected optimizer are read from
    ``args`` (e.g. ``momentum`` for SGD, ``beta_1``/``beta_2`` for the Adam
    family, ``eps`` for ALTO).

    Args:
        args: Namespace-like object carrying ``opt`` and the hyper-parameters
            of the selected optimizer.
        model: Model whose parameters are optimized. The torch optimizers
            receive ``model.parameters()``; the LAMB and ALTO factories
            receive the model itself.

    Returns:
        The constructed optimizer (possibly wrapped in ``OptimizerAdaptor``).

    Raises:
        ValueError: If ``args.opt`` is not one of the supported names.
    """

    def make_sgd():
        return torch.optim.SGD(
            model.parameters(),
            args.lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay,
        )

    def make_adam_like(optim_cls):
        # Adam and AdamW take the same keyword arguments here.
        return optim_cls(
            model.parameters(),
            lr=args.lr,
            betas=(args.beta_1, args.beta_2),
            weight_decay=args.weight_decay,
        )

    def make_lamb():
        return create_lamb_optimizer(
            model, lr=args.lr, weight_decay=args.weight_decay
        )

    def make_alto():
        return create_ALTO_optimizer(
            model,
            lr=args.lr,
            betas=(args.beta, args.beta_1, args.beta_2),
            alpha=args.alpha,
            weight_decay=args.weight_decay,
            eps=args.eps,
        )

    def adapted(base_optimizer):
        # The "e*" variants reuse a base optimizer through OptimizerAdaptor.
        return OptimizerAdaptor(base_optimizer, args.alpha, args.beta)

    # Builders are callables so nothing is constructed (and no attribute of
    # ``args`` beyond ``opt`` is touched) until the requested name is known.
    builders = {
        'sgd': make_sgd,
        'esgd': lambda: adapted(make_sgd()),
        'adam': lambda: make_adam_like(torch.optim.Adam),
        'eadam': lambda: adapted(make_adam_like(torch.optim.Adam)),
        'adamw': lambda: make_adam_like(torch.optim.AdamW),
        'eadamw': lambda: adapted(make_adam_like(torch.optim.AdamW)),
        'alto': make_alto,
        'lamb': make_lamb,
        'elamb': lambda: adapted(make_lamb()),
    }

    requested = args.opt
    if requested not in builders:
        raise ValueError('unknown optimizer: {}'.format(requested))
    return builders[requested]()